Amazon SageMakerがTensorFlow 2.0に対応しました

Amazon SageMakerがTensorFlow 2.0に対応しました

Clock Icon2020.01.22

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

Amazon SageMaker を使うことでインフラの構築や管理をすることなく、機械学習モデルの学習やホスティングが可能となります。 また、SageMakerではTensorFlowやPyTorch、MXNetといった深層学習フレームワークに対応しており、モデルの学習用スクリプトやデータの場所等を指定することでSageMakerが用意したコンテナ上でモデルの学習を実行させることができます。

そんなSageMakerでTensorFlow 2.0が使えるようになりました!

これまでも専用のコンテナイメージを作成することで、TensorFlow 2.0をSageMakerで使うことは可能でした。今回のアップデートによって、TensorFlow 2.0がSageMakerでコンテナイメージの作成の必要なく、ネイティブで使えるようになりました。便利です。

やってみる

早速、TensorFlow 2.0を利用したスクリプトでMNISTの分類モデルの学習を以下のノートブックを元に試してみたいと思います。このノートブックでは学習に使用するデータが既にパブリックアクセス可能なS3に保存されているため、SageMakerでできることを手軽に体験できます。

SageMakerでTensorFlow 2.0を使うために必要なのは、framework_version2.0.0を指定するだけです。それ以外はこれまでと同様です。

entry_pointとして設定してあるmnist-2.pyはTensorFlow 2.0で書かれた学習用スクリプトです。mnist_estimator2.fitを叩くと、SageMaker上でそのスクリプトが実行されます。

import os
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()
role = get_execution_role()
region = sagemaker_session.boto_session.region_name

# 学習用と評価用データの保存場所
training_data_uri = 's3://sagemaker-sample-data-{}/tensorflow/mnist'.format(region)

# SageMakerでの学習をハンドルするTensorFlow用Estimator
mnist_estimator2 = TensorFlow(entry_point='mnist-2.py',
                             role=role,
                             train_instance_count=2,
                             train_instance_type='ml.m5.xlarge',
                             framework_version='2.0.0', # 使用するTensorFlowのバージョン
                             py_version='py3',
                             distributions={'parameter_server': {'enabled': True}})

# 学習
mnist_estimator2.fit(training_data_uri)

....

マネジメントコンソールからも学習の結果を参照できます。 763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/tensorflow-training:2.0.0-cpu-py3が学習用のコンテナイメージとして使われていることがわかります。

モデルのホスティング

モデルのホスティングや推論もこれまでと同様にすることができます。 Estimatorからdeployを叩くことでモデルをデプロイし、ホスティングします。

predictor2 = mnist_estimator2.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')

数分でデプロイが完了します。

推論に使用するデータをダウンロードします。

import numpy as np

!aws --region {region} s3 cp s3://sagemaker-sample-data-{region}/tensorflow/mnist/train_data.npy train_data.npy
!aws --region {region} s3 cp s3://sagemaker-sample-data-{region}/tensorflow/mnist/train_labels.npy train_labels.npy

train_data = np.load('train_data.npy')
train_labels = np.load('train_labels.npy')

Predictorでpredictを叩くことで指定したデータを推論することができます。

predictions2 = predictor2.predict(train_data[:50])
for i in range(0, 50):
    prediction = predictions2['predictions'][i]
    label = train_labels[i]
    print('prediction is {}, label is {}, matched: {}'.format(prediction, label, prediction == label))

エンドポイント名を指定して、エンドポイントを削除します。

sagemaker.Session().delete_endpoint(predictor2.endpoint)

さいごに

Amazon SageMakerでTensorFlow 2.0を使う方法について紹介しました。TensorFlow 2.0によってパフォーマンスが向上したり、シンプルにスクリプトが書けるようになったりと様々なメリットがありますが、SageMakerでもそれらを享受できます。活用の幅がより広がりそうです。

参考

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.